
#include "statsxx/statistics/covariance_functions/Polynomial.hpp"

// STL
#include <cmath>     // std::pow()
#include <memory>    // std::unique_ptr<>
#include <vector>    // std::vector<>

// jScience
#include "jScience/linalg.hpp" // Vector<>, dot_product()


//
// DESC: Constructor.
//
//     di : polynomial degree; used as the exponent in cov()
//     th : hyperparameter vector, handed to the CovarianceFunction base
//          (presumably stored there as `theta`; cov() reads theta[0] as the
//          divisor of x1.x2 and theta[1] as the additive offset)
//
inline Polynomial::Polynomial(const int di, const std::vector<double> &th)
    : CovarianceFunction(th)
{
    d = di;
}

// DESC: Destructor (empty body; nothing is released explicitly here).
inline Polynomial::~Polynomial() {}


//
// DESC: Polymorphic copy. Copy-constructs a new Polynomial from *this and
//       hands ownership to the caller through a base-class unique_ptr.
//
inline std::unique_ptr<CovarianceFunction> Polynomial::clone() const
{
    std::unique_ptr<CovarianceFunction> copy(new Polynomial(*this));
    return copy;
}


//
// DESC: Evaluates the polynomial covariance
//
//           k(x1, x2) = ( x1.x2 / theta[0] + theta[1] )^d
//
//       i.e. the standard polynomial kernel (x^T y + c)^d with the dot
//       product additionally divided by theta[0].
//
// NOTE: See: https://en.wikipedia.org/wiki/Polynomial_kernel
//
inline double Polynomial::cov(const Vector<double> &x1, const Vector<double> &x2)
{
    const double scale  = this->theta[0];
    const double offset = this->theta[1];

    const double base = dot_product(x1, x2)/scale + offset;

    return std::pow(base, this->d);
}


//
// DESC: Gradient of cov(x1, x2) with respect to the hyperparameters.
//
//       With s = x1.x2/theta[0] + theta[1] and k = s^d:
//
//           dk/dtheta[0] = -d * s^(d-1) * (x1.x2) / theta[0]^2
//           dk/dtheta[1] =  d * s^(d-1)
//
//       Returns a 2-element vector {dk/dtheta[0], dk/dtheta[1]}.
//
// NOTE: Calculated with: http://www.wolframalpha.com/
//
inline std::vector<double> Polynomial::ddtheta(const Vector<double> &x1, const Vector<double> &x2)
{
    std::vector<double> derivs(2);

    const double xTy       = dot_product(x1, x2);
    const double xTy_a_p_c = xTy/this->theta[0] + this->theta[1];

    // dk/ds = d*s^(d-1), shared by both derivatives. For d == 0 the kernel is
    // the constant 1 and the derivative is exactly zero; handling it
    // explicitly avoids 0*pow(0, -1) = 0*inf = NaN when s == 0.
    const double dk_ds = (this->d == 0)
                         ? 0.0
                         : (this->d)*std::pow(xTy_a_p_c, (this->d-1));

    derivs[0] = -xTy*dk_ds/(this->theta[0]*this->theta[0]);
    derivs[1] = dk_ds;

    return derivs;
}


//
// DESC: Gradient of cov(x, x2) with respect to the first argument x.
//
//       With s = x.x2/theta[0] + theta[1] and k = s^d:
//
//           dk/dx_i = d * s^(d-1) * x2_i / theta[0]
//
//       The result has x.size() entries; x2 is indexed with the same range,
//       so x2 is assumed to be at least as long as x (dot_product presumably
//       requires equal lengths — confirm).
//
// NOTE: Calculated with: http://www.wolframalpha.com/
//
inline Vector<double> Polynomial::ddx(const Vector<double> &x, const Vector<double> &x2)
{
    auto n = x.size();
    Vector<double> derivs(n);

    const double xTy_a_p_c = dot_product(x, x2)/this->theta[0] + this->theta[1];

    // dk/ds = d*s^(d-1); exactly zero for d == 0 (constant kernel), which also
    // avoids 0*pow(0, -1) = NaN when s == 0.
    const double dk_ds = (this->d == 0)
                         ? 0.0
                         : (this->d)*std::pow(xTy_a_p_c, (this->d-1));

    // loop-invariant part of dk/dx_i
    const double factor = dk_ds/this->theta[0];

    // index with the container's own size type to avoid signed/unsigned comparison
    for( decltype(n) i = 0; i < n; ++i )
    {
        derivs(i) = x2(i)*factor;
    }

    return derivs;
}
